{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MindSpore Data Format Conversion"
   ]
  },
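  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Overview\n",
    "\n",
    "Users can convert both non-standard datasets and commonly used datasets into the MindSpore data format, MindRecord, so that the data can be loaded into MindSpore conveniently for training. MindSpore also applies performance optimizations in some scenarios, so using MindRecord can yield better performance."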
   ]
  },
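  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Converting Non-Standard Datasets to MindRecord\n",
    "\n",
    "This section describes how to convert CV and NLP data to MindRecord and how to read the resulting MindRecord files with `MindDataset`.\n",
    "\n",
    "### Converting a CV Dataset\n",
    "\n",
    "This example shows how to convert your own CV dataset to MindRecord and read it back with `MindDataset`.\n",
    "\n",
    "It first creates a MindRecord file containing 100 records, where each sample has three fields: `file_name` (string), `label` (integer), and `data` (binary). It then reads the file with `MindDataset`.\n",
    "\n",
    "1. Import the required modules."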
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from io import BytesIO\n",
    "import os\n",
    "import mindspore.dataset as ds\n",
    "from mindspore.mindrecord import FileWriter\n",
    "import mindspore.dataset.vision.c_transforms as vision\n",
    "from PIL import Image"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "2. Generate 100 images and convert them to MindRecord."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MSRStatus.SUCCESS"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "MINDRECORD_FILE = \"test.mindrecord\"\n",
    "\n",
    "# Remove files left over from a previous run.\n",
    "for path in (MINDRECORD_FILE, MINDRECORD_FILE + \".db\"):\n",
    "    if os.path.exists(path):\n",
    "        os.remove(path)\n",
    "\n",
    "writer = FileWriter(file_name=MINDRECORD_FILE, shard_num=1)\n",
    "\n",
    "cv_schema = {\"file_name\": {\"type\": \"string\"}, \"label\": {\"type\": \"int32\"}, \"data\": {\"type\": \"bytes\"}}\n",
    "writer.add_schema(cv_schema, \"it is a cv dataset\")\n",
    "\n",
    "writer.add_index([\"file_name\", \"label\"])\n",
    "\n",
    "data = []\n",
    "for i in range(1, 101):\n",
    "    # Encode a white (i*10) x (i*10) image as JPEG in memory.\n",
    "    white_io = BytesIO()\n",
    "    Image.new('RGB', (i*10, i*10), (255, 255, 255)).save(white_io, 'JPEG')\n",
    "    image_bytes = white_io.getvalue()\n",
    "\n",
    "    sample = {}\n",
    "    sample['file_name'] = str(i) + \".jpg\"\n",
    "    sample['label'] = i\n",
    "    sample['data'] = image_bytes\n",
    "\n",
    "    data.append(sample)\n",
    "    if i % 10 == 0:\n",
    "        writer.write_raw_data(data)\n",
    "        data = []\n",
    "\n",
    "if data:\n",
    "    writer.write_raw_data(data)\n",
    "\n",
    "writer.commit()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Parameter description:**\n",
    "- `MINDRECORD_FILE`: path of the output MindRecord file."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "3. Read the MindRecord file with `MindDataset`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Got 100 samples\n"
     ]
    }
   ],
   "source": [
    "data_set = ds.MindDataset(dataset_file=MINDRECORD_FILE)\n",
    "decode_op = vision.Decode()\n",
    "data_set = data_set.map(operations=decode_op, input_columns=[\"data\"], num_parallel_workers=2)\n",
    "count = 0\n",
    "for item in data_set.create_dict_iterator(output_numpy=True):\n",
    "    count += 1\n",
    "print(\"Got {} samples\".format(count))"
   ]
  },
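  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The conversion loop in step 2 buffers samples and passes them to `write_raw_data` ten at a time, then flushes whatever is left. The buffering pattern on its own can be sketched in plain Python as below. This is a minimal illustration, not part of the MindSpore API: the `batched` helper and the stand-in records are hypothetical names introduced here.\n",
    "\n",
    "```python\n",
    "from typing import Dict, Iterable, Iterator, List\n",
    "\n",
    "\n",
    "def batched(samples: Iterable[Dict], batch_size: int) -> Iterator[List[Dict]]:\n",
    "    # Hypothetical helper: yield lists of at most batch_size samples.\n",
    "    batch = []\n",
    "    for sample in samples:\n",
    "        batch.append(sample)\n",
    "        if len(batch) == batch_size:\n",
    "            yield batch\n",
    "            batch = []\n",
    "    if batch:  # flush the final, possibly shorter, batch\n",
    "        yield batch\n",
    "\n",
    "\n",
    "# Stand-in records; a real run would build the image samples from step 2.\n",
    "samples = ({'file_name': str(i) + '.jpg', 'label': i} for i in range(1, 26))\n",
    "batch_sizes = [len(batch) for batch in batched(samples, 10)]\n",
    "print(batch_sizes)  # prints [10, 10, 5]\n",
    "```\n",
    "\n",
    "With a `FileWriter`, each yielded batch would be passed to `writer.write_raw_data(batch)`, followed by a single `writer.commit()` at the end."
   ]
  },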
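  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "### Converting an NLP Dataset\n",
    "\n",
    "This example shows how to convert your own NLP dataset to MindRecord and read it back with `MindDataset`. For brevity, the preprocessing step that maps text to vocabulary IDs is omitted.\n",
    "\n",
    "It first creates a MindRecord file containing 100 records, where each sample has eight fields, all of them integer arrays. It then reads the file with `MindDataset`.\n",
    "\n",
    "1. Import the required modules."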
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import mindspore.dataset as ds\n",
    "from mindspore.mindrecord import FileWriter"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "2. Generate 100 text records and convert them to MindRecord."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MSRStatus.SUCCESS"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "MINDRECORD_FILE = \"test.mindrecord\"\n",
    "\n",
    "# Remove files left over from a previous run.\n",
    "for path in (MINDRECORD_FILE, MINDRECORD_FILE + \".db\"):\n",
    "    if os.path.exists(path):\n",
    "        os.remove(path)\n",
    "\n",
    "writer = FileWriter(file_name=MINDRECORD_FILE, shard_num=1)\n",
    "\n",
    "nlp_schema = {\"source_sos_ids\": {\"type\": \"int64\", \"shape\": [-1]},\n",
    "              \"source_sos_mask\": {\"type\": \"int64\", \"shape\": [-1]},\n",
    "              \"source_eos_ids\": {\"type\": \"int64\", \"shape\": [-1]},\n",
    "              \"source_eos_mask\": {\"type\": \"int64\", \"shape\": [-1]},\n",
    "              \"target_sos_ids\": {\"type\": \"int64\", \"shape\": [-1]},\n",
    "              \"target_sos_mask\": {\"type\": \"int64\", \"shape\": [-1]},\n",
    "              \"target_eos_ids\": {\"type\": \"int64\", \"shape\": [-1]},\n",
    "              \"target_eos_mask\": {\"type\": \"int64\", \"shape\": [-1]}}\n",
    "writer.add_schema(nlp_schema, \"it is a preprocessed nlp dataset\")\n",
    "\n",
    "data = []\n",
    "for i in range(1, 101):\n",
    "    # Each field is a variable-length int64 array, matching \"shape\": [-1] in the schema.\n",
    "    sample = {\"source_sos_ids\": np.array([i, i + 1, i + 2, i + 3, i + 4], dtype=np.int64),\n",
    "              \"source_sos_mask\": np.array([i * 1, i * 2, i * 3, i * 4, i * 5, i * 6, i * 7], dtype=np.int64),\n",
    "              \"source_eos_ids\": np.array([i + 5, i + 6, i + 7, i + 8, i + 9, i + 10], dtype=np.int64),\n",
    "              \"source_eos_mask\": np.array([19, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64),\n",
    "              \"target_sos_ids\": np.array([28, 29, 30, 31, 32], dtype=np.int64),\n",
    "              \"target_sos_mask\": np.array([33, 34, 35, 36, 37, 38], dtype=np.int64),\n",
    "              \"target_eos_ids\": np.array([39, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),\n",
    "              \"target_eos_mask\": np.array([48, 49, 50, 51], dtype=np.int64)}\n",
    "\n",
    "    data.append(sample)\n",
    "    if i % 10 == 0:\n",
    "        writer.write_raw_data(data)\n",
    "        data = []\n",
    "\n",
    "if data:\n",
    "    writer.write_raw_data(data)\n",
    "\n",
    "writer.commit()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Parameter description:**\n",
    "- `MINDRECORD_FILE`: path of the output MindRecord file."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "3. Read the MindRecord file with `MindDataset`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Got 100 samples\n"
     ]
    }
   ],
   "source": [
    "data_set = ds.MindDataset(dataset_file=MINDRECORD_FILE)\n",
    "count = 0\n",
    "for item in data_set.create_dict_iterator():\n",
    "    count += 1\n",
    "print(\"Got {} samples\".format(count))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Converting Common Datasets to MindRecord\n",
    "\n",
    "MindSpore provides utility classes that convert commonly used datasets to MindRecord. Some of these datasets and their corresponding utility classes are listed below.\n",
    "\n",
    "| Dataset | Conversion Utility Class |\n",
    "| :------- | :----------- |\n",
    "| CIFAR-10 | Cifar10ToMR  |\n",
    "| ImageNet | ImageNetToMR |\n",
    "| TFRecord | TFRecordToMR |\n",
    "| CSV File | CsvToMR |\n",
    "\n",
    "For details about converting other datasets, see the [API documentation](https://www.mindspore.cn/doc/api_python/zh-CN/master/mindspore/mindspore.mindrecord.html)."
   ]
  },
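  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Converting the CIFAR-10 Dataset\n",
    "\n",
    "The `Cifar10ToMR` class converts the original CIFAR-10 data to MindRecord, which can then be read with `MindDataset`.\n",
    "\n",
    "1. Download the CIFAR-10 dataset and extract it to the target directory by running the following commands:"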
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2020-12-09 16:16:04--  https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/datasets/cifar10-py.zip\n",
      "Resolving proxy-notebook.modelarts-dev-proxy.com (proxy-notebook.modelarts-dev-proxy.com)... 192.168.0.172\n",
      "Connecting to proxy-notebook.modelarts-dev-proxy.com (proxy-notebook.modelarts-dev-proxy.com)|192.168.0.172|:8083... connected.\n",
      "Proxy request sent, awaiting response... 200 OK\n",
      "Length: 170647614 (163M) [application/zip]\n",
      "Saving to: ‘cifar10-py.zip’\n",
      "\n",
      "cifar10-py.zip      100%[===================>] 162.74M   139MB/s    in 1.2s    \n",
      "\n",
      "2020-12-09 16:16:06 (139 MB/s) - ‘cifar10-py.zip’ saved [170647614/170647614]\n",
      "\n",
      "Archive:  cifar10-py.zip\n",
      "   creating: ./datasets/cifar-10-batches-py/\n",
      "  inflating: ./datasets/cifar-10-batches-py/batches.meta  \n",
      "  inflating: ./datasets/cifar-10-batches-py/data_batch_1  \n",
      "  inflating: ./datasets/cifar-10-batches-py/data_batch_2  \n",
      "  inflating: ./datasets/cifar-10-batches-py/data_batch_3  \n",
      "  inflating: ./datasets/cifar-10-batches-py/data_batch_4  \n",
      "  inflating: ./datasets/cifar-10-batches-py/data_batch_5  \n",
      "  inflating: ./datasets/cifar-10-batches-py/readme.html  \n",
      "  inflating: ./datasets/cifar-10-batches-py/test_batch  \n",
      "./datasets/cifar-10-batches-py/\n",
      "├── batches.meta\n",
      "├── data_batch_1\n",
      "├── data_batch_2\n",
      "├── data_batch_3\n",
      "├── data_batch_4\n",
      "├── data_batch_5\n",
      "├── readme.html\n",
      "└── test_batch\n",
      "\n",
      "0 directories, 8 files\n"
     ]
    }
   ],
   "source": [
    "!wget -N https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/datasets/cifar10-py.zip\n",
    "!unzip -o cifar10-py.zip -d ./datasets\n",
    "!tree ./datasets/cifar-10-batches-py/"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "2. Import the required modules."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import mindspore.dataset as ds\n",
    "import mindspore.dataset.vision.c_transforms as vision\n",
    "from mindspore.mindrecord import Cifar10ToMR"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "3. Create a `Cifar10ToMR` object and call its `transform` interface to convert the CIFAR-10 dataset to MindRecord."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MSRStatus.SUCCESS"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ds_target_path = \"./datasets/mindspore_dataset_conversion/\"\n",
    "# Remove files left over from a previous run.\n",
    "os.system(\"rm -f {}*\".format(ds_target_path))\n",
    "\n",
    "!mkdir -p $ds_target_path\n",
    "CIFAR10_DIR = \"./datasets/cifar-10-batches-py/\"\n",
    "MINDRECORD_FILE = \"./datasets/mindspore_dataset_conversion/cifar10.mindrecord\"\n",
    "cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, MINDRECORD_FILE)\n",
    "cifar10_transformer.transform(['label'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Parameter description:**\n",
    "- `CIFAR10_DIR`: path of the CIFAR-10 dataset.\n",
    "- `MINDRECORD_FILE`: path of the output MindRecord file."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "4. Read the MindRecord file with `MindDataset`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Got 50000 samples\n"
     ]
    }
   ],
   "source": [
    "data_set = ds.MindDataset(dataset_file=MINDRECORD_FILE)\n",
    "decode_op = vision.Decode()\n",
    "data_set = data_set.map(operations=decode_op, input_columns=[\"data\"], num_parallel_workers=2)\n",
    "count = 0\n",
    "for item in data_set.create_dict_iterator(output_numpy=True):\n",
    "    count += 1\n",
    "print(\"Got {} samples\".format(count))"
   ]
  },
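  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Converting a CSV Dataset\n",
    "\n",
    "This example first creates a CSV file containing 5 records, then converts it to MindRecord with the `CsvToMR` utility class, and finally reads it back with `MindDataset`.\n",
    "\n",
    "1. Import the required modules."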
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import csv\n",
    "import os\n",
    "import mindspore.dataset as ds\n",
    "from mindspore.mindrecord import CsvToMR"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "2. Generate the CSV file and convert it to MindRecord."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "CSV_FILE = \"test.csv\"\n",
    "MINDRECORD_FILE = \"test.mindrecord\"\n",
    "\n",
    "def generate_csv():\n",
    "    headers = [\"id\", \"name\", \"math\", \"english\"]\n",
    "    rows = [(1, \"Lily\", 78.5, 90),\n",
    "            (2, \"Lucy\", 99, 85.2),\n",
    "            (3, \"Mike\", 65, 71),\n",
    "            (4, \"Tom\", 95, 99),\n",
    "            (5, \"Jeff\", 85, 78.5)]\n",
    "    # newline='' lets the csv module control line endings on every platform.\n",
    "    with open(CSV_FILE, 'w', encoding='utf-8', newline='') as f:\n",
    "        writer = csv.writer(f)\n",
    "        writer.writerow(headers)\n",
    "        writer.writerows(rows)\n",
    "\n",
    "generate_csv()\n",
    "\n",
    "# Remove files left over from a previous run.\n",
    "for path in (MINDRECORD_FILE, MINDRECORD_FILE + \".db\"):\n",
    "    if os.path.exists(path):\n",
    "        os.remove(path)\n",
    "\n",
    "csv_transformer = CsvToMR(CSV_FILE, MINDRECORD_FILE, partition_number=1)\n",
    "\n",
    "csv_transformer.transform()\n",
    "\n",
    "assert os.path.exists(MINDRECORD_FILE)\n",
    "assert os.path.exists(MINDRECORD_FILE + \".db\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Parameter description:**\n",
    "- `CSV_FILE`: path of the CSV file.\n",
    "- `MINDRECORD_FILE`: path of the output MindRecord file."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "3. Read the MindRecord file with `MindDataset`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Got 5 samples\n"
     ]
    }
   ],
   "source": [
    "data_set = ds.MindDataset(dataset_file=MINDRECORD_FILE)\n",
    "count = 0\n",
    "for item in data_set.create_dict_iterator(output_numpy=True):\n",
    "    count += 1\n",
    "print(\"Got {} samples\".format(count))"
   ]
  },
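  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before converting a CSV file, it can help to confirm its header row and record count with the standard `csv` module. The sketch below is a minimal, MindSpore-independent illustration under stated assumptions: the `scores.csv` file name, the temporary directory, and the two sample rows are hypothetical and are not used elsewhere in this tutorial.\n",
    "\n",
    "```python\n",
    "import csv\n",
    "import os\n",
    "import tempfile\n",
    "\n",
    "# Hypothetical file in a temporary directory; the tutorial itself uses test.csv.\n",
    "csv_path = os.path.join(tempfile.mkdtemp(), 'scores.csv')\n",
    "headers = ['id', 'name', 'math', 'english']\n",
    "rows = [(1, 'Lily', 78.5, 90), (2, 'Lucy', 99, 85.2)]\n",
    "\n",
    "# newline='' lets the csv module control line endings on every platform.\n",
    "with open(csv_path, 'w', encoding='utf-8', newline='') as f:\n",
    "    writer = csv.writer(f)\n",
    "    writer.writerow(headers)\n",
    "    writer.writerows(rows)\n",
    "\n",
    "# Read the file back; csv.DictReader returns every value as a string.\n",
    "with open(csv_path, encoding='utf-8', newline='') as f:\n",
    "    records = list(csv.DictReader(f))\n",
    "\n",
    "print(len(records), list(records[0].keys()))\n",
    "print(records[0]['math'])\n",
    "```\n",
    "\n",
    "The same check can be pointed at `CSV_FILE` to verify that the file holds the expected five records before `transform` is called."
   ]
  },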
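  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Converting a TFRecord Dataset\n",
    "\n",
    "> TensorFlow 1.13.0-rc1 and later versions are currently supported."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This example requires TensorFlow. If it is not installed, run the following command to install it. After the installation completes, close this notebook and reopen it before running the remaining cells."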
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TensorFlow installed\n"
     ]
    }
   ],
   "source": [
    "os.system('pip install tensorflow') if os.system('python -c \"import tensorflow\"') else print(\"TensorFlow installed\")"
   ]
  },
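  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This example first creates a TFRecord file with TensorFlow, then converts it to MindRecord with the `TFRecordToMR` utility class, and finally reads it back with `MindDataset`, using the `Decode` operator to decode the `image_bytes` field.\n",
    "\n",
    "1. Import the required modules."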
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import collections\n",
    "from io import BytesIO\n",
    "import os\n",
    "import mindspore.dataset as ds\n",
    "from mindspore.mindrecord import TFRecordToMR\n",
    "import mindspore.dataset.vision.c_transforms as vision\n",
    "from PIL import Image\n",
    "import tensorflow as tf"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "2. Generate the TFRecord file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Write 10 rows in tfrecord.\n"
     ]
    }
   ],
   "source": [
    "TFRECORD_FILE = \"test.tfrecord\"\n",
    "MINDRECORD_FILE = \"test.mindrecord\"\n",
    "\n",
    "def generate_tfrecord():\n",
    "    def create_int_feature(values):\n",
    "        if isinstance(values, list):\n",
    "            feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))\n",
    "        else:\n",
    "            feature = tf.train.Feature(int64_list=tf.train.Int64List(value=[values]))\n",
    "        return feature\n",
    "\n",
    "    def create_float_feature(values):\n",
    "        if isinstance(values, list):\n",
    "            feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))\n",
    "        else:\n",
    "            feature = tf.train.Feature(float_list=tf.train.FloatList(value=[values]))\n",
    "        return feature\n",
    "\n",
    "    def create_bytes_feature(values):\n",
    "        if isinstance(values, bytes):\n",
    "            # The placeholder bytes are ignored; a 10x10 white JPEG is stored instead so the field holds valid image data.\n",
    "            white_io = BytesIO()\n",
    "            Image.new('RGB', (10, 10), (255, 255, 255)).save(white_io, 'JPEG')\n",
    "            image_bytes = white_io.getvalue()\n",
    "            feature = tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_bytes]))\n",
    "        else:\n",
    "            feature = tf.train.Feature(bytes_list=tf.train.BytesList(value=[bytes(values, encoding='utf-8')]))\n",
    "        return feature\n",
    "\n",
    "    writer = tf.io.TFRecordWriter(TFRECORD_FILE)\n",
    "\n",
    "    example_count = 0\n",
    "    for i in range(10):\n",
    "        file_name = \"000\" + str(i) + \".jpg\"\n",
    "        image_bytes = bytes(str(\"aaaabbbbcccc\" + str(i)), encoding=\"utf-8\")\n",
    "        int64_scalar = i\n",
    "        float_scalar = float(i)\n",
    "        int64_list = [i, i+1, i+2, i+3, i+4, i+1234567890]\n",
    "        float_list = [float(i), float(i+1), float(i+2.8), float(i+3.2),\n",
    "                    float(i+4.4), float(i+123456.9), float(i+98765432.1)]\n",
    "\n",
    "        features = collections.OrderedDict()\n",
    "        features[\"file_name\"] = create_bytes_feature(file_name)\n",
    "        features[\"image_bytes\"] = create_bytes_feature(image_bytes)\n",
    "        features[\"int64_scalar\"] = create_int_feature(int64_scalar)\n",
    "        features[\"float_scalar\"] = create_float_feature(float_scalar)\n",
    "        features[\"int64_list\"] = create_int_feature(int64_list)\n",
    "        features[\"float_list\"] = create_float_feature(float_list)\n",
    "\n",
    "        tf_example = tf.train.Example(features=tf.train.Features(feature=features))\n",
    "        writer.write(tf_example.SerializeToString())\n",
    "        example_count += 1\n",
    "    writer.close()\n",
    "    print(\"Write {} rows in tfrecord.\".format(example_count))\n",
    "\n",
    "generate_tfrecord()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Parameter description:**\n",
    "- `TFRECORD_FILE`: path of the TFRecord file.\n",
    "- `MINDRECORD_FILE`: path of the output MindRecord file."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "3. Convert the TFRecord file to MindRecord."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "feature_dict = {\"file_name\": tf.io.FixedLenFeature([], tf.string),\n",
    "                \"image_bytes\": tf.io.FixedLenFeature([], tf.string),\n",
    "                \"int64_scalar\": tf.io.FixedLenFeature([], tf.int64),\n",
    "                \"float_scalar\": tf.io.FixedLenFeature([], tf.float32),\n",
    "                \"int64_list\": tf.io.FixedLenFeature([6], tf.int64),\n",
    "                \"float_list\": tf.io.FixedLenFeature([7], tf.float32),\n",
    "                }\n",
    "\n",
    "# Remove files left over from a previous run.\n",
    "for path in (MINDRECORD_FILE, MINDRECORD_FILE + \".db\"):\n",
    "    if os.path.exists(path):\n",
    "        os.remove(path)\n",
    "\n",
    "tfrecord_transformer = TFRecordToMR(TFRECORD_FILE, MINDRECORD_FILE, feature_dict, [\"image_bytes\"])\n",
    "tfrecord_transformer.transform()\n",
    "\n",
    "assert os.path.exists(MINDRECORD_FILE)\n",
    "assert os.path.exists(MINDRECORD_FILE + \".db\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "4. Read the MindRecord file with `MindDataset`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Got 10 samples\n"
     ]
    }
   ],
   "source": [
    "data_set = ds.MindDataset(dataset_file=MINDRECORD_FILE)\n",
    "decode_op = vision.Decode()\n",
    "data_set = data_set.map(operations=decode_op, input_columns=[\"image_bytes\"], num_parallel_workers=2)\n",
    "count = 0\n",
    "for item in data_set.create_dict_iterator(output_numpy=True):\n",
    "    count += 1\n",
    "print(\"Got {} samples\".format(count))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "MindSpore-1.0.1",
   "language": "python",
   "name": "mindspore-1.0.1"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
